VQVAE means Vector Quantized Variational AutoEncoder. Let’s break down the name.
Auto-Encoder
This model takes the input, passes it through smaller layers (encodes) and then tries to reproduce the input from that small layer (decodes). It thus needs to learn which features are the most important to keep! If the input was images of cats, the model would keep information about the color, the way it's looking, the thiccness of the cat.
Variational
Instead of encoding inputs as single points in space, we map them onto distributions. In our case, we simply want to make sure that, on average, the space we're mapping to is covered uniformly. This means all codes will be used.
Vector Quantized
We want to define a set of points in the smaller representation, the latent. These points will be called codes and be part of a codebook. Think of the codebook as the set of possible words the encoder can use to describe what it sees to the decoder. Obviously, the more codes we have, the more information the encoder will be able to pass per code sent.
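To make the snapping concrete, here is a minimal sketch in plain NumPy (the tiny 2-D codebook and the incoming vectors are made up for illustration):

```python
import numpy as np

# A toy codebook of 4 codes in 2-D; the encoder output gets snapped to
# the nearest code before being handed to the decoder.
codebook = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

def quantize(vectors, codebook):
    # Squared Euclidean distance between every vector and every code.
    distances = ((vectors[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    indices = distances.argmin(axis=-1)
    return codebook[indices], indices

vectors = np.array([[0.1, -0.2], [0.9, 0.8]])
quantized, idx = quantize(vectors, codebook)  # snapped to codes 0 and 3
```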
Our VQVAE is going to find the best codes to describe speech. It'll take in special images called Mel-Frequency Spectrograms, which are basically a way to represent human speech as an image.
Since our final goal is to recreate a 1 to 1 version of the VQVAE used in XTTS, we’ll hardcode a lot of things to minimize issues.
We have 4 parts to code: a ResBlock, an Encoder, a Decoder and a Quantizer. Let's get into it!
Importing JAX and Equinox, a library that helps with writing neural networks.
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import equinox.nn as nn
import typing as tp
```
3.1 ResBlocks
The role of this block is mainly to exchange information between the various parts of each channel, but the input is added back at the end. This allows our network, if suitable, to simply set all the weights to zero and be "shallower": basically, our network decides how many layers it needs!
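That "decide to be shallower" property can be seen in a toy example (a single made-up linear layer standing in for the convolutions): with the weights at zero, the residual block reduces to the identity.

```python
import numpy as np

def res_block(x, w):
    # One linear "layer" plus the skip connection that adds the input back.
    return x + w @ x

x = np.array([1.0, 2.0, 3.0])
# Zero weights: the block passes the input through untouched.
y = res_block(x, np.zeros((3, 3)))
```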
```python
class ResBlock(eqx.Module):
    conv1: nn.Conv1d
    conv2: nn.Conv1d
    conv3: nn.Conv1d
    act: tp.Callable = eqx.static_field()

    def __init__(self, dim: int, activation=jax.nn.relu, key=None):
        key1, key2, key3 = jax.random.split(key, 3)
        self.conv1 = nn.Conv1d(dim, dim, kernel_size=3, padding="SAME", key=key1)
        self.conv2 = nn.Conv1d(dim, dim, kernel_size=3, padding="SAME", key=key2)
        self.conv3 = nn.Conv1d(dim, dim, kernel_size=1, padding="SAME", key=key3)
        self.act = activation

    def __call__(self, x):
        y = x
        y = self.conv1(y)
        y = self.act(y)
        y = self.conv2(y)
        y = self.act(y)
        y = self.conv3(y)
        y = y + x
        return y
```
3.2 Encoder
Moving onto the Encoder. It has layers that take in the input and slowly compress it by lowering the image dimensions and increasing the number of channels, much like ResNet:
```python
class Encoder(eqx.Module):
    conv1: nn.Conv1d
    conv2: nn.Conv1d
    conv3: nn.Conv1d
    res1: ResBlock
    res2: ResBlock
    res3: ResBlock

    def __init__(self, hidden_dim: int = 1024, codebook_dim: int = 512, key=None):
        key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6)
        self.conv1 = nn.Conv1d(in_channels=80, out_channels=512, kernel_size=3, stride=2, padding="SAME", key=key1)
        self.conv2 = nn.Conv1d(in_channels=512, out_channels=hidden_dim, kernel_size=3, stride=2, padding="SAME", key=key2)
        self.res1 = ResBlock(dim=hidden_dim, key=key3)
        self.res2 = ResBlock(dim=hidden_dim, key=key4)
        self.res3 = ResBlock(dim=hidden_dim, key=key5)
        self.conv3 = nn.Conv1d(in_channels=hidden_dim, out_channels=codebook_dim, kernel_size=1, stride=1, padding="SAME", key=key6)

    def __call__(self, x):
        y = self.conv1(x)
        y = jax.nn.relu(y)
        y = self.conv2(y)
        y = jax.nn.relu(y)
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.conv3(y)
        return y
```
3.3 Decoder
We can now implement the decoder. Instead of using ConvTranspose1d here, XTTS uses upsampling and interpolation between points, replacing the striding that would usually happen. We implement it in Section 3.3.1 just below.
```python
class Decoder(eqx.Module):
    conv1: nn.Conv1d
    conv2: UpsampledConv
    conv3: UpsampledConv
    conv4: nn.Conv1d
    res1: ResBlock
    res2: ResBlock
    res3: ResBlock

    def __init__(self, hidden_dim: int = 1024, codebook_dim: int = 512, key=None):
        key1, key2, key3, key4, key5, key6, key7 = jax.random.split(key, 7)
        self.conv1 = nn.Conv1d(in_channels=codebook_dim, out_channels=hidden_dim, kernel_size=1, stride=1, padding="SAME", key=key1)
        self.res1 = ResBlock(dim=hidden_dim, key=key2)
        self.res2 = ResBlock(dim=hidden_dim, key=key3)
        self.res3 = ResBlock(dim=hidden_dim, key=key4)
        self.conv2 = UpsampledConv(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, stride=2, padding="SAME", key=key5)
        self.conv3 = UpsampledConv(in_channels=hidden_dim, out_channels=512, kernel_size=3, stride=2, padding="SAME", key=key6)
        self.conv4 = nn.Conv1d(in_channels=512, out_channels=80, kernel_size=1, stride=1, padding="SAME", key=key7)

    def __call__(self, x):
        y = self.conv1(x)
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.conv2(y)
        y = jax.nn.relu(y)
        y = self.conv3(y)
        y = jax.nn.relu(y)
        y = self.conv4(y)
        return y
```
3.3.1 UpsampledConv
Before the decoder can run we have to define a special layer that replaces the ConvTranspose we would normally use. I admit I'm still not sure why this choice was made.
Their code for this function:
```python
class UpsampledConv(nn.Module):
    def __init__(self, conv, *args, **kwargs):
        super().__init__()
        assert "stride" in kwargs.keys()
        self.stride = kwargs["stride"]
        del kwargs["stride"]
        self.conv = conv(*args, **kwargs)

    def forward(self, x):
        up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest")
        return self.conv(up)
```
We can execute the torch version to check how it works and then compare to our solution
After having implemented various VQVAEs, what strikes me in this one is that the ResBlocks are all separated from the various convolutional stages, and that there is no normalisation between layers. Moving onto the crux of the matter, the Quantizer!
3.4 Quantizer
So our encoder spits out a certain number of vectors (the number of channels) of the codebook dimension. These vectors are mapped to their nearest neighbours in the codebook, and those are then transmitted to the decoder. To stabilise the model, we're also going to add an exponential moving average to the codebook, as well as normalize things to keep values from exploding or vanishing. Exponential moving average means that we mostly keep what we currently have and mix in a little of the new value, instead of fully changing things every time.
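The moving average itself is a one-liner; a small standalone example (with made-up numbers) of how a value creeps towards a new observation:

```python
def ema(old, new, decay=0.99):
    # Keep `decay` of the old value, mix in (1 - decay) of the new one.
    return decay * old + (1 - decay) * new

value = 0.0
# Feed in a constant observation of 1.0 a hundred times:
for _ in range(100):
    value = ema(value, 1.0)
# The average has only moved to 1 - 0.99**100, roughly 0.63.
```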
We can't update the codebook here though, as Equinox modules are immutable. We need to update it between training steps instead.
```python
class Quantizer(eqx.Module):
    K: int = eqx.static_field()
    D: int = eqx.static_field()
    codebook: jax.Array
    codebook_avg: jax.Array
    cluster_size: jax.Array
    decay: float = eqx.static_field()
    eps: float = eqx.static_field()

    def __init__(self, num_vecs: int = 1024, num_dims: int = 512, decay: float = 0.99, eps: float = 1e-5, key=None):
        self.K = num_vecs
        self.D = num_dims
        self.decay = decay
        self.eps = eps
        # Init a matrix of vectors that will move with time
        self.codebook = jax.nn.initializers.variance_scaling(
            scale=1.0, mode="fan_in", distribution="uniform"
        )(key, (num_vecs, num_dims))
        self.codebook_avg = jnp.copy(self.codebook)
        self.cluster_size = jnp.zeros(num_vecs)

    def __call__(self, x):
        # x has N vectors of the codebook dimension. We calculate the nearest
        # neighbors and output those instead
        flatten = jnp.reshape(x, (-1, self.D))
        a_squared = jnp.sum(flatten**2, axis=-1, keepdims=True)
        b_squared = jnp.transpose(jnp.sum(self.codebook**2, axis=-1, keepdims=True))
        distance = a_squared + b_squared - 2 * jnp.matmul(flatten, jnp.transpose(self.codebook))
        codebook_indices = jnp.argmin(distance, axis=-1)
        z_q = self.codebook[codebook_indices]
        # Straight-through estimator
        z_q = flatten + jax.lax.stop_gradient(z_q - flatten)
        z_q = jnp.permute_dims(z_q, (1, 0))
        return z_q, self.codebook_updates(flatten, codebook_indices)

    def codebook_updates(self, flatten, codebook_indices):
        # Calculate the usage of various codes.
        codebook_onehot = jax.nn.one_hot(codebook_indices, self.K)
        codebook_onehot_sum = jnp.sum(codebook_onehot, axis=0)
        codebook_sum = jnp.dot(flatten.T, codebook_onehot)
        # We've just weighed the codebook vectors.
        # Basically count on average how many codes we're using
        new_cluster_size = self.decay * self.cluster_size + (1 - self.decay) * codebook_onehot_sum
        # Where is the average embedding at?
        new_codebook_avg = self.decay * self.codebook_avg + (1 - self.decay) * codebook_sum.T
        n = jnp.sum(new_cluster_size)  # Over the total embeddings used
        new_cluster_size = (new_cluster_size + self.eps) / (n + self.K * self.eps) * n
        new_codebook = self.codebook_avg / new_cluster_size[:, None]
        updates = (new_cluster_size, new_codebook_avg, new_codebook)
        return updates, codebook_indices
```
We can visualize vectors being mapped to the nearest codes below, where light blue vectors come in and are snapped to the closest red vectors. TODO fix the bug
We can see that we go from an \([80 \times X]\) spectrogram down to a \([512 \times \frac{X}{4}]\) latent, which to be fair seems kind of dumb because we don't actually lose any information like this. We could have a 1 to 1 reproduction of the image… were it not for the quantizer in the middle that forces us to describe the image with only \(\frac{X}{4}\) codes! Whatever the dimension of these vectors, the information we encode can thus be counted in \(\frac{X}{4} \cdot \log_2(K)\) bits (with \(K\) the codebook size), which makes it quite small 🤏😎
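A quick sanity check on that count, with illustrative numbers (a 400-frame spectrogram and the default codebook size of 1024):

```python
import math

X = 400             # spectrogram frames
K = 1024            # codebook entries
num_codes = X // 4  # latent positions after the two stride-2 convolutions
# Each position carries log2(K) bits, so the whole latent carries:
bits = num_codes * math.log2(K)  # 100 positions * 10 bits each
```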
4 Training
Figure 2: How not to train things
4.1 Codebook special thingies
There are a few things that we need to do. First we need to write a function that will update our model based on the returned codebook updates. It should not only update the various values we're keeping track of, but, since we're using a VQVAE and that imposes a uniform distribution over the codes, we also need to yeet codes that are being used too often or too rarely.
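The "too often / too rarely" test can be sketched with a toy histogram (made-up counts over 4 codes, uniform prior of 1/4 each):

```python
import numpy as np

# Toy assignments: code 0 soaked up 5 of the 8 vectors.
indices = np.array([0, 0, 0, 0, 0, 1, 2, 3])
h = np.bincount(indices, minlength=4) / len(indices)
target = 1 / 4
# Flag codes used more than twice, or less than half, their expected share.
mask = (h > 2 * target) | (h < 0.5 * target)  # only code 0 gets flagged
```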
Note
VAEs work because we impose that they don't just encode the important information but encode it in a uniform way, so that vectors are well spread out and not clustered. Normally you need to add a loss term checking how closely the embeddings follow that spread (usually a Gaussian distribution). In VQ-VAEs we impose a uniform distribution, which makes this term constant, so we don't have to add it to the loss.
Figure 3: With replacing codebook outliers
Figure 4: Without replacement
```python
import optax
from tensorboardX import SummaryWriter

def update_codebook_ema(model: VQVAE, updates: tuple, codebook_indices, key=None):
    avg_updates = jax.tree.map(lambda x: jnp.mean(x, axis=0), updates)
    # Calculate which codes are too often used and yeet them. Prior is uniform.
    h = jnp.histogram(codebook_indices, bins=model.quantizer.K, range=(0, model.quantizer.K))[0] / len(codebook_indices)
    part_that_should_be = 1 / model.quantizer.K
    mask = (h > 2 * part_that_should_be) | (h < 0.5 * part_that_should_be)
    rand_embed = jax.random.normal(key, (model.quantizer.K, model.quantizer.D)) * mask[:, None]
    avg_updates = (avg_updates[0], avg_updates[1], jnp.where(mask[:, None], rand_embed, avg_updates[2]))
    where = lambda q: (q.quantizer.cluster_size, q.quantizer.codebook_avg, q.quantizer.codebook)
    # Update the codebook and other trackers.
    model = eqx.tree_at(where, model, avg_updates)
    return model
```
4.2 Losses and gradients
Nearly there! 😮💨 We can now write out our two "classic" functions: one that calculates how badly our boi outputs garbage, by comparing outputs to inputs with an L2 reconstruction loss, and how far the encoder's vectors are from the available codes (the commitment loss); and one that takes a full optimization step with the resulting gradients.
```python
@eqx.filter_jit
@eqx.filter_value_and_grad(has_aux=True)
def calculate_losses(model, x):
    z_e, z_q, codebook_updates, y = jax.vmap(model)(x)
    # Are the inputs and outputs close?
    reconstruct_loss = jnp.mean(jnp.linalg.norm((x - y), ord=2, axis=(1, 2)))
    # Are the output vectors z_e close to the codes z_q?
    commit_loss = jnp.mean(jnp.linalg.norm(z_e - jax.lax.stop_gradient(z_q), ord=2, axis=(1, 2)))
    total_loss = reconstruct_loss + commit_loss
    return total_loss, (reconstruct_loss, commit_loss, codebook_updates, y)

@eqx.filter_jit
def make_step(model, optimizer, opt_state, x, key):
    (total_loss, (reconstruct_loss, commit_loss, codebook_updates, y)), grads = calculate_losses(model, x)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    model = update_codebook_ema(model, codebook_updates[0], codebook_updates[1], key)
    return model, opt_state, total_loss, reconstruct_loss, commit_loss, codebook_updates, y
```
4.3 Preparing the data
Let’s download some data and make runs through to see if it can learn from it !
XTTS has a function that transforms the incoming wav files into nice mel spectrograms. To optimize the time spent loading the data, we'll transform all the data into input arrays first, and then load from those during the run instead.
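A sketch of that caching pass; the `wav_to_mel` argument stands in for XTTS's actual transform (not reproduced here), and the directory layout is illustrative:

```python
import os
import numpy as np

def cache_mels(wav_dir, mel_dir, wav_to_mel):
    # Convert every wav exactly once and dump the [80 x T] array to disk;
    # during training we then load the .npy files instead of recomputing.
    os.makedirs(mel_dir, exist_ok=True)
    for name in os.listdir(wav_dir):
        if not name.endswith(".wav"):
            continue
        mel = wav_to_mel(os.path.join(wav_dir, name))
        np.save(os.path.join(mel_dir, name.replace(".wav", ".npy")), mel)
```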
Now that we have a folder full of mel spectrograms ready to be fed into our model, we can start the training ! Below, we do multiple things:
Initialize the model and the optimizer that will nudge it based on the losses our function returns, set up logging with tensorboard, and save the model every epoch.
With a little bit of patience, we can see the output image start to resemble the input one more and more!
Result after just 10 minutes of training on an NVIDIA L40
In the tensorboard we can also see that, over the course of training, the codewords progressively all get used in equal amounts. It's beautiful 🥰
Codebooks slowly all being used uniformly.
This concludes this chapter, if you have any questions or remarks feel free to reach out to me ! (sxyBoi?) on Telegram 😅
5 References
Casanova, Edresson, Kelly Davis, Eren Gölge, Görkem Göknar, Iulian Gulea, Logan Hart, Aya Aljafari, et al. 2024. “XTTS: A Massively Multilingual Zero-Shot Text-to-Speech Model.” https://arxiv.org/abs/2406.04904.
Défossez, Alexandre, Jade Copet, Gabriel Synnaeve, and Yossi Adi. 2022. “High Fidelity Neural Audio Compression.” https://arxiv.org/abs/2210.13438.
Dumakude, A., and A. E. Ezugwu. 2023. “Image from "Automated COVID-19 Detection with Convolutional Neural Networks".”Scientific Reports 13: 10607. https://doi.org/10.1038/s41598-023-37743-4.